# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from functools import reduce
from mindspore.ops import DataType
from mindspore.ops import TBERegOp
from mindspore.ops import op_info_register
from mindspore import dtype as mstype
from mindspore.ops import Custom

from accspeed.ops.dropout_mask_decoder.constants import DROPOUT_MASK_DECODER_KERNEL_NAME
from accspeed.ops.dropout_mask_decoder.dropout_mask_decoder import decode_dropout_mask

# Operator-information record for the TBE kernel behind the
# "DropoutMaskDecoderPrimitive" primitive. It is used both by the
# ``op_info_register`` decorator on ``dropout_mask_decoder`` and as the
# ``reg_info`` of the ``Custom`` op built in ``get_dropout_mask_decoder``.
dropout_mask_decoder_op_info = (
    TBERegOp("DropoutMaskDecoderPrimitive")
    # NOTE(review): "OPAQUE" presumably keeps this op out of graph fusion — confirm
    # against the TBERegOp documentation.
    .fusion_type("OPAQUE")
    .partial_flag(True)
    .async_flag(False)
    # The binary file name and the kernel name are both derived from the shared
    # kernel-name constant, so they stay in sync.
    .binfile_name(f"{DROPOUT_MASK_DECODER_KERNEL_NAME}.so")
    .compute_cost(10)
    .kernel_name(DROPOUT_MASK_DECODER_KERNEL_NAME)
    # A single required input named "input_mask" and a single required output
    # named "res"; these names match init_prim_io_names in get_dropout_mask_decoder.
    .input(0, "input_mask", False, "required", "all")
    .output(0, "res", False, "required", "all")
    # The only registered dtype/format combination: uint8 input -> uint8 output,
    # both in the default format.
    .dtype_format(DataType.U8_Default, DataType.U8_Default)
    .get_op_info()
)


@op_info_register(dropout_mask_decoder_op_info)
def dropout_mask_decoder(input_mask, res, kernel_name=DROPOUT_MASK_DECODER_KERNEL_NAME):
    """TBE entry point registered with ``dropout_mask_decoder_op_info``.

    Forwards its arguments to ``decode_dropout_mask`` and discards whatever that
    call returns, so this function itself returns ``None``.

    :param input_mask: the "input_mask" input declared in the op info (uint8);
        presumably a TBE tensor descriptor — confirm against decode_dropout_mask.
    :param res: the "res" output declared in the op info (uint8); same caveat.
    :param kernel_name: name passed through to the kernel builder; defaults to
        the shared ``DROPOUT_MASK_DECODER_KERNEL_NAME`` constant.
    """
    target_kernel_name = kernel_name
    decode_dropout_mask(input_mask, res, kernel_name=target_kernel_name)


def get_dropout_mask_decoder(bsz):
    """Build a ``Custom`` TBE op that expands a packed uint8 dropout mask.

    The op's output has eight times as many uint8 elements as its input and is
    shaped ``(bsz, total_output_elements // bsz)``.

    :param bsz: leading (batch) dimension of the decoded output; must be > 0.
    :return: a ``mindspore.ops.Custom`` instance wrapping ``decode_dropout_mask``,
        registered with ``dropout_mask_decoder_op_info`` and with primitive
        I/O names ``["input_mask"]`` -> ``["res"]``.
    :raises ValueError: if ``bsz`` is not positive.
    """
    if bsz <= 0:
        raise ValueError(f"bsz must be a positive integer, got {bsz}")

    def infer_shape(input_mask_shape):
        """
        :param input_mask_shape: (bsz, head_dim * seq_len * seq_len / 8)
        :return: (bsz, head_dim * seq_len * seq_len)
        :raises ValueError: if the decoded element count is not divisible by bsz.
        """
        # NOTE(review): "head_dim" above likely means the number of attention
        # heads — confirm. Dynamic (negative) dims are not handled specially.
        # Initializer 1 makes an empty (scalar) shape count as one byte instead of
        # reduce() raising TypeError on an empty sequence.
        input_bytes_num = reduce(lambda x, y: x * y, input_mask_shape, 1)
        # Each input byte expands to 8 uint8 output elements (see shapes above).
        o_mask_bytes_num = input_bytes_num * 8
        o_mask_sec_dim, remainder = divmod(o_mask_bytes_num, bsz)
        if remainder:
            # Floor division would silently drop elements and produce an output
            # shape that does not match what the kernel writes.
            raise ValueError(
                f"decoded mask size {o_mask_bytes_num} (from input shape "
                f"{tuple(input_mask_shape)}) is not divisible by bsz={bsz}"
            )
        return bsz, o_mask_sec_dim

    def infer_dtype(input_mask_dtype):
        # The output dtype is always uint8, regardless of the input dtype.
        return mstype.uint8

    # Named decoder_op to avoid shadowing the module-level dropout_mask_decoder.
    decoder_op = Custom(decode_dropout_mask, out_shape=infer_shape, out_dtype=infer_dtype, func_type="tbe",
                        bprop=None, reg_info=dropout_mask_decoder_op_info)
    decoder_op.init_prim_io_names(
        inputs=["input_mask"],
        outputs=["res"]
    )
    return decoder_op
